{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SQL Agent for the Spider Text-to-SQL Benchmark"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates a basic SQL agent that translates natural language questions into SQL queries."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment\n",
    "\n",
    "For this demo, we use a SQLite database environment based on [Spider](https://yale-lily.github.io/spider), a standard text-to-SQL benchmark. The environment provides a Gym-style interface (`reset` and `step`) and can be used as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install spider-env\n",
    "import json\n",
    "import os\n",
    "from typing import Annotated\n",
    "\n",
    "from spider_env import SpiderEnv\n",
    "\n",
    "from autogen import ConversableAgent, UserProxyAgent, config_list_from_json\n",
    "\n",
    "gym = SpiderEnv()\n",
    "\n",
    "# Randomly select a question from Spider\n",
    "observation, info = gym.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The natural language question\n",
    "question = observation[\"instruction\"]\n",
    "print(question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The schema of the corresponding database\n",
    "schema = info[\"schema\"]\n",
    "print(schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agent Implementation\n",
    "\n",
    "Using AG2, a SQL agent can be implemented with a `ConversableAgent`. The gym environment executes the generated SQL query, and the agent uses the execution result as feedback to refine its query over multiple rounds of conversation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
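    "# Do not require Docker for the user proxy's code execution\n",
    "os.environ[\"AUTOGEN_USE_DOCKER\"] = \"False\"\n",
    "# Load LLM settings from the OAI_CONFIG_LIST environment variable or file\n",
    "config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n",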
    "\n",
    "\n",
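    "def check_termination(msg: dict) -> bool:\n",
    "    \"\"\"Stop the chat once execute_sql() reports success (a response without an 'error' key).\"\"\"\n",
    "    if \"tool_responses\" not in msg:\n",
    "        return False\n",
    "    json_str = msg[\"tool_responses\"][0][\"content\"]\n",
    "    obj = json.loads(json_str)\n",
    "    # execute_sql() includes an \"error\" key only on failure and never returns a \"reward\" key\n",
    "    return \"error\" not in obj\n",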
    "\n",
    "\n",
    "sql_writer = ConversableAgent(\n",
    "    \"sql_writer\",\n",
    "    llm_config={\"config_list\": config_list},\n",
    "    system_message=\"You are good at writing SQL queries. Always respond with a function call to execute_sql().\",\n",
    "    is_termination_msg=check_termination,\n",
    ")\n",
    "user_proxy = UserProxyAgent(\"user_proxy\", human_input_mode=\"NEVER\", max_consecutive_auto_reply=5)\n",
    "\n",
    "\n",
    "@sql_writer.register_for_llm(description=\"Function for executing SQL query and returning a response\")\n",
    "@user_proxy.register_for_execution()\n",
    "def execute_sql(\n",
    "    reflection: Annotated[str, \"Think about what to do\"], sql: Annotated[str, \"SQL query\"]\n",
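    ") -> Annotated[dict[str, str], \"Dictionary with 'result' on success, or 'error', 'wrong_result' and 'correct_result' on failure\"]:\n",
    "    # Execute the query in the Spider environment\n",
    "    observation, reward, _, _, info = gym.step(sql)\n",
    "    error = observation[\"feedback\"][\"error\"]\n",
    "    # A query that runs without error but earns zero reward produced the wrong result\n",
    "    if not error and reward == 0:\n",
    "        error = \"The SQL query returned an incorrect result\"\n",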
    "    if error:\n",
    "        return {\n",
    "            \"error\": error,\n",
    "            \"wrong_result\": observation[\"feedback\"][\"result\"],\n",
    "            \"correct_result\": info[\"gold_result\"],\n",
    "        }\n",
    "    else:\n",
    "        return {\n",
    "            \"result\": observation[\"feedback\"][\"result\"],\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
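    "Given the database schema and the natural language question, the agent generates a SQL query and refines it until the environment accepts the result or the user proxy reaches its auto-reply limit."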
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message = f\"\"\"Below is the schema for a SQL database:\n",
    "{schema}\n",
    "Generate a SQL query to answer the following question:\n",
    "{question}\n",
    "\"\"\"\n",
    "\n",
    "user_proxy.initiate_chat(sql_writer, message=message)"
   ]
  }
 ],
 "metadata": {
  "front_matter": {
   "description": "Natural language text to SQL query using the Spider text-to-SQL benchmark.",
   "tags": [
    "SQL",
    "tool/function"
   ]
  },
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
